#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" By Martin Senande-Rivera
    For Spatial and temporal expansion of global wildland fire activity in response to climate change """

import os
import glob, sys
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from mpl_toolkits.basemap import Basemap, maskoceans
import matplotlib.colors as mcolors
import datetime as dt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import ListedColormap
import cmaps
from pandas import Series, DataFrame
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


# Small global font size for every text element in the figure.
font = {'size': 5}
matplotlib.rc('font', **font)

# Input directories, resolved relative to the current working directory.
ruta_fc_val = '../Classification/3-Validation/'
ruta_fc_p = '../Classification/2-Classification/'
ruta_fc_f = '../Classification/4-Future_Classification/'


def _load_var(directory, filename, varname):
    """Return variable ``varname`` from the netCDF file ``directory/filename``.

    The variable is loaded fully into memory and the file is closed before
    returning, so no netCDF handle stays open for the rest of the script.
    """
    with xr.open_dataset(os.path.join(directory, filename)) as ds:
        return ds[varname].load()


# Validation map (stored under xarray's default name for an unnamed
# DataArray), the classification, the future classification and the
# confidence of the future classification.
FC_val = _load_var(ruta_fc_val, 'FC_val.nc', '__xarray_dataarray_variable__')
FC_p = _load_var(ruta_fc_p, 'FC.nc', 'FC')
FC_f = _load_var(ruta_fc_f, 'FC.nc', 'FC')
Confidence = _load_var(ruta_fc_f, 'FC_confidence.nc', 'FC_confidence')

# One global map in EPSG:4326 (longitude/latitude) coordinates; the same
# Basemap instance is reused for all three panels.
m = Basemap(epsg=4326, resolution='i',
            llcrnrlat=-90., urcrnrlat=90.,
            llcrnrlon=-180., urcrnrlon=180.)

# 2-D coordinate grids for pcolormesh/contourf.
# NOTE(review): assumes FC_val, FC_f and Confidence share FC_p's lat/lon
# grid -- TODO confirm.
lats = FC_p['lat']
lons = FC_p['lon']
x, y = np.meshgrid(lons.values, lats.values)




# Figure with three vertically stacked map panels (a, b, c).
fig = plt.figure(figsize=(6, 9))
gs = fig.add_gridspec(3, 1)

# --- Panel a: validation map (FC_val) --------------------------------------
ax1 = fig.add_subplot(gs[0, 0])
m.drawcoastlines(linewidth=0.25, zorder=3, ax=ax1)
m.drawcountries(linewidth=0.1, zorder=3, ax=ax1)
# Grey oceans/lakes, transparent land, drawn below coastlines/borders.
m.drawlsmask(land_color='none', ocean_color="#e4e4e4", lakes=True,
             zorder=2, ax=ax1)

# Integer class code -> colour.  Codes 0, 5 and 6 exist only in this panel.
col_dict = {0: 'w',
            5: 'k',
            6: 'grey',
            11: '#145305',
            12: '#218607',
            13: '#8ade75',
            21: '#cb6c00',
            22: '#ffc021',
            23: '#ffdb87',
            31: '#9f007f',
            32: '#ff6bcb',
            33: '#f3c4ff',
            41: '#063555',
            42: '#0e84d2',
            43: '#81c0e9'}
# One label per class code, in ascending-code order.
labels = np.array(['BA=0ha | NC', 'BA=0ha | C', 'BA>0ha | NC',
                   'Tr-ds-r', 'Tr-ds-o', 'Tr-ds-i',
                   'Ar-fl-r', 'Ar-fl-o', 'Ar-fl-i',
                   'Te-dhs-r', 'Te-dhs-o', 'Te-dhs-i',
                   'Bo-hs-r', 'Bo-hs-o', 'Bo-hs-i'])
len_lab = len(labels)

# Colormap and norm are both built from the *sorted* codes, so the i-th
# colour always belongs to the i-th BoundaryNorm bin regardless of the
# dict's insertion order.
codes = sorted(col_dict)
cm = ListedColormap([col_dict[code] for code in codes])
# Bin edges halfway between integer codes, plus one edge below the smallest
# code, so every code falls in its own bin.
norm_bins = np.asarray(codes) + 0.5
norm_bins = np.insert(norm_bins, 0, norm_bins[0] - 1.0)
norm = matplotlib.colors.BoundaryNorm(norm_bins, len_lab, clip=True)

# Zero-area fills at (0, 0) serve only as legend swatches, one per class.
for code, label in zip(codes, labels):
    ax1.fill_between([0., 0.], [0., 0.], color=col_dict[code], alpha=1.,
                     label=label)
ax1.legend(loc='lower center', bbox_to_anchor=(0.5, -0.1), fontsize=5,
           ncol=5, fancybox=True, framealpha=1.)

cs1 = m.pcolormesh(x, y, FC_val, cmap=cm, norm=norm, ax=ax1)
ax1.text(0.01, 0.99, 'a', family='sans-serif', weight='bold', size=7,
         horizontalalignment='left', verticalalignment='top',
         transform=ax1.transAxes)

# --- Panel b: classification map (FC_p) -------------------------------------
ax2 = fig.add_subplot(gs[1, 0])
m.drawcoastlines(linewidth=0.25, zorder=3, ax=ax2)
m.drawcountries(linewidth=0.1, zorder=3, ax=ax2)
# Grey oceans/lakes, transparent land, drawn below coastlines/borders.
m.drawlsmask(land_color='none', ocean_color="#e4e4e4", lakes=True,
             zorder=2, ax=ax2)

# Integer class code -> colour (code 0 = 'No fires').
col_dict = {0: 'w',
            11: '#145305',
            12: '#218607',
            13: '#8ade75',
            21: '#cb6c00',
            22: '#ffc021',
            23: '#ffdb87',
            31: '#9f007f',
            32: '#ff6bcb',
            33: '#f3c4ff',
            41: '#063555',
            42: '#0e84d2',
            43: '#81c0e9'}
# One label per class code, in ascending-code order.
labels = np.array(['No fires',
                   'Tr-ds-r', 'Tr-ds-o', 'Tr-ds-i',
                   'Ar-fl-r', 'Ar-fl-o', 'Ar-fl-i',
                   'Te-dhs-r', 'Te-dhs-o', 'Te-dhs-i',
                   'Bo-hs-r', 'Bo-hs-o', 'Bo-hs-i'])
len_lab = len(labels)

# Colormap and norm are both built from the *sorted* codes, so the i-th
# colour always belongs to the i-th BoundaryNorm bin regardless of the
# dict's insertion order.
codes = sorted(col_dict)
cm = ListedColormap([col_dict[code] for code in codes])
# Bin edges halfway between integer codes, plus one edge below the smallest
# code, so every code falls in its own bin.
norm_bins = np.asarray(codes) + 0.5
norm_bins = np.insert(norm_bins, 0, norm_bins[0] - 1.0)
norm = matplotlib.colors.BoundaryNorm(norm_bins, len_lab, clip=True)

# Zero-area fills at (0, 0) serve only as legend swatches.  Code 0
# ('No fires', white) is deliberately left out of the legend.
for code, label in zip(codes[1:], labels[1:]):
    ax2.fill_between([0., 0.], [0., 0.], color=col_dict[code], alpha=1.,
                     label=label)
# NOTE(review): panels a and c use fancybox=True; confirm the square legend
# corners here are intentional.
ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -0.1), fontsize=5,
           ncol=4, fancybox=False, framealpha=1.)

cs2 = m.pcolormesh(x, y, FC_p, cmap=cm, norm=norm, ax=ax2)
ax2.text(0.01, 0.99, 'b', family='sans-serif', weight='bold', size=7,
         horizontalalignment='left', verticalalignment='top',
         transform=ax2.transAxes)

# --- Panel c: future classification (FC_f) with low-confidence shading -----
ax3 = fig.add_subplot(gs[2, 0])
m.drawcoastlines(linewidth=0.25, zorder=3, ax=ax3)
m.drawcountries(linewidth=0.1, zorder=3, ax=ax3)
# Grey oceans/lakes, transparent land, drawn below coastlines/borders.
m.drawlsmask(land_color='none', ocean_color="#e4e4e4", lakes=True,
             zorder=2, ax=ax3)

# Integer class code -> colour (code 0 = 'No fires').
col_dict = {0: 'w',
            11: '#145305',
            12: '#218607',
            13: '#8ade75',
            21: '#cb6c00',
            22: '#ffc021',
            23: '#ffdb87',
            31: '#9f007f',
            32: '#ff6bcb',
            33: '#f3c4ff',
            41: '#063555',
            42: '#0e84d2',
            43: '#81c0e9'}
# One label per class code, in ascending-code order.
labels = np.array(['No fires',
                   'Tr-ds-r', 'Tr-ds-o', 'Tr-ds-i',
                   'Ar-fl-r', 'Ar-fl-o', 'Ar-fl-i',
                   'Te-dhs-r', 'Te-dhs-o', 'Te-dhs-i',
                   'Bo-hs-r', 'Bo-hs-o', 'Bo-hs-i'])
len_lab = len(labels)

# Colormap and norm are both built from the *sorted* codes, so the i-th
# colour always belongs to the i-th BoundaryNorm bin regardless of the
# dict's insertion order.
codes = sorted(col_dict)
cm = ListedColormap([col_dict[code] for code in codes])
# Bin edges halfway between integer codes, plus one edge below the smallest
# code, so every code falls in its own bin.
norm_bins = np.asarray(codes) + 0.5
norm_bins = np.insert(norm_bins, 0, norm_bins[0] - 1.0)
norm = matplotlib.colors.BoundaryNorm(norm_bins, len_lab, clip=True)

# Zero-area fills at (0, 0) serve only as legend swatches.  Code 0
# ('No fires', white) is deliberately left out of the legend.
for code, label in zip(codes[1:], labels[1:]):
    ax3.fill_between([0., 0.], [0., 0.], color=col_dict[code], alpha=1.,
                     label=label)
ax3.legend(loc='lower center', bbox_to_anchor=(0.5, -0.1), fontsize=5,
           ncol=4, fancybox=True, framealpha=1.)

cs2 = m.pcolormesh(x, y, FC_f, cmap=cm, norm=norm, ax=ax3)

# Cells with Confidence < 75 become 1; every other cell (including missing
# data) becomes NaN and is left unshaded by contourf.
# NOTE(review): 75 presumably is a percentage threshold -- confirm units.
low_confidence = xr.where(Confidence < 75., 1., np.nan)
cs3 = m.contourf(x, y, low_confidence, np.arange(0., 1.01, 0.01),
                 cmap=plt.get_cmap('binary'), alpha=0.5, ax=ax3)
ax3.text(0.01, 0.99, 'c', family='sans-serif', weight='bold', size=7,
         horizontalalignment='left', verticalalignment='top',
         transform=ax3.transAxes)

# Stack the panels with no gap between them, then write the raster figure.
fig.subplots_adjust(top=0.95, bottom=0.05, hspace=0., wspace=0.)

fig.savefig('Figure2.png', dpi=300, bbox_inches='tight')
# Optional vector export of the same figure (disabled):
# fig.savefig('Figure2.pdf', bbox_inches='tight')
